{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) OpenMMLab. All rights reserved.\n",
    "\n",
    "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
    "\n",
    "## Masked Autoencoders: Visualization Demo\n",
    "\n",
    "This is a visualization demo using our pre-trained MAE models. No GPU is needed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare\n",
    "Check the environment and install the required packages when running in Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Detect Colab by checking whether its package has already been imported.\n",
    "if 'google.colab' in sys.modules:\n",
    "    print('Running in Colab.')\n",
    "    # `-U` installs openmim when it is missing and upgrades it otherwise,\n",
    "    # so a separate plain install is not needed.\n",
    "    !pip install -U openmim\n",
    "    !mim install mmengine\n",
    "    !mim install 'mmcv>=2.0.0rc1'\n",
    "\n",
    "    !git clone https://github.com/open-mmlab/mmselfsup.git\n",
    "    %cd mmselfsup/\n",
    "    !git checkout dev-1.x\n",
    "    !pip install -e .\n",
    "\n",
    "    %cd demo\n",
    "\n",
    "# The notebook is expected to run from the repository's `demo/` directory\n",
    "# (the Colab branch changes into it above), so the repository root that\n",
    "# contains the `mmselfsup` package is `..`. A relative `sys.path` entry is\n",
    "# looked up from the current working directory, so it has to be valid from\n",
    "# `demo/`.\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from argparse import ArgumentParser\n",
    "from typing import Tuple, Optional\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from mmengine.dataset import Compose, default_collate\n",
    "\n",
    "from mmselfsup.apis import inference_model, init_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the utils\n",
    "\n",
    "# Per-channel mean and std, presumably the usual ImageNet RGB statistics\n",
    "# (TODO confirm). They are on a 0-1 scale: `recover_norm` multiplies by 255\n",
    "# after applying them.\n",
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "\n",
    "def show_image(img: torch.Tensor, title: str = '') -> None:\n",
    "    \"\"\"Draw an image with a title and hidden axes using pyplot.\n",
    "\n",
    "    Args:\n",
    "        img: Image whose third dimension must have size 3 (asserted).\n",
    "        title: Text shown above the image.\n",
    "    \"\"\"\n",
    "    num_channels = img.shape[2]\n",
    "    assert num_channels == 3\n",
    "\n",
    "    plt.imshow(img)\n",
    "    plt.axis('off')\n",
    "    plt.title(title, fontsize=16)\n",
    "\n",
    "\n",
    "def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,\n",
    "                pred_img: torch.Tensor, img_paste: torch.Tensor,\n",
    "                out_file: Optional[str] = None) -> None:\n",
    "    \"\"\"Plot the four images side by side, then show or save the figure.\n",
    "\n",
    "    Args:\n",
    "        original_img: Image drawn in the first panel ('original').\n",
    "        img_masked: Image drawn in the second panel ('masked').\n",
    "        pred_img: Image drawn in the third panel ('reconstruction').\n",
    "        img_paste: Image drawn in the fourth panel\n",
    "            ('reconstruction + visible').\n",
    "        out_file: When None the figure is displayed with ``plt.show()``;\n",
    "            otherwise it is written to this path with ``plt.savefig``.\n",
    "    \"\"\"\n",
    "    # make the plt figure larger (this changes the global rcParams)\n",
    "    plt.rcParams['figure.figsize'] = [24, 6]\n",
    "\n",
    "    panels = (\n",
    "        (original_img, 'original'),\n",
    "        (img_masked, 'masked'),\n",
    "        (pred_img, 'reconstruction'),\n",
    "        (img_paste, 'reconstruction + visible'),\n",
    "    )\n",
    "    # subplot positions are 1-based, one row of four columns\n",
    "    for position, (image, caption) in enumerate(panels, start=1):\n",
    "        plt.subplot(1, 4, position)\n",
    "        show_image(image, caption)\n",
    "\n",
    "    if out_file is not None:\n",
    "        plt.savefig(out_file)\n",
    "        print(f'Images are saved to {out_file}')\n",
    "    else:\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "def recover_norm(img: torch.Tensor,\n",
    "                 mean: np.ndarray = imagenet_mean,\n",
    "                 std: np.ndarray = imagenet_std):\n",
    "    \"\"\"Undo mean/std normalization and convert to integer pixel values.\n",
    "\n",
    "    Args:\n",
    "        img: Normalized image; ``std`` and ``mean`` must broadcast against it.\n",
    "        mean: Offset added back after scaling by ``std``.\n",
    "        std: Scale multiplied into ``img`` before ``mean`` is added.\n",
    "\n",
    "    Returns:\n",
    "        ``(img * std + mean) * 255`` clipped to [0, 255] and converted with\n",
    "        ``.int()``, or ``img`` unchanged when ``mean`` or ``std`` is None.\n",
    "    \"\"\"\n",
    "    if mean is None or std is None:\n",
    "        return img\n",
    "    rescaled = (img * std + mean) * 255\n",
    "    return torch.clip(rescaled, 0, 255).int()\n",
    "\n",
    "\n",
    "def post_process(\n",
    "    original_img: torch.Tensor,\n",
    "    pred_img: torch.Tensor,\n",
    "    mask: torch.Tensor,\n",
    "    mean: np.ndarray = imagenet_mean,\n",
    "    std: np.ndarray = imagenet_std\n",
    ") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Build the four images to display from a reconstruction result.\n",
    "\n",
    "    Args:\n",
    "        original_img: Input batch in ``nchw`` layout; it is moved to the CPU\n",
    "            and converted to ``nhwc``.\n",
    "        pred_img: Reconstructed batch. It is combined element-wise with the\n",
    "            ``nhwc`` original, so it is presumably ``nhwc`` too (TODO confirm).\n",
    "        mask: Element-wise weights: where it is 1 the prediction is used,\n",
    "            where it is 0 the original image is kept.\n",
    "        mean: Accepted but not used by this function.\n",
    "        std: Accepted but not used by this function.\n",
    "\n",
    "    Returns:\n",
    "        The original, masked, reconstructed and pasted images for the first\n",
    "        sample of the batch, each passed through ``recover_norm`` with that\n",
    "        function's default statistics.\n",
    "    \"\"\"\n",
    "    # channel conversion\n",
    "    img_nhwc = torch.einsum('nchw->nhwc', original_img.cpu())\n",
    "    keep = 1 - mask\n",
    "    # masked image\n",
    "    visible = img_nhwc * keep\n",
    "    # reconstructed image pasted with visible patches\n",
    "    pasted = img_nhwc * keep + pred_img * mask\n",
    "\n",
    "    # multiply by std and add mean for the first sample of each batch\n",
    "    first_original = recover_norm(img_nhwc[0])\n",
    "    first_masked = recover_norm(visible[0])\n",
    "    first_pred = recover_norm(pred_img[0])\n",
    "    first_pasted = recover_norm(pasted[0])\n",
    "\n",
    "    return first_original, first_masked, first_pred, first_pasted\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load a pre-trained MAE model\n",
    "\n",
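    "This is an MAE model trained with the config 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k.py'. The checkpoint is loaded below with the config file 'mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py'.\n"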
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2022-11-08 11:00:50--  https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
      "正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
      "正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n",
      "已发出 HTTP 请求，正在等待回应... 200 OK\n",
      "长度： 1355429265 (1.3G) [application/octet-stream]\n",
      "正在保存至: “mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth”\n",
      "\n",
      "e_in1k_20220825-cc7  99%[==================> ]   1.26G   913KB/s    剩余 0s    s "
     ]
    }
   ],
   "source": [
    "# Download the checkpoint unless the file already exists (`-nc`: no-clobber).\n",
    "!wget -nc https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "local loads checkpoint from path: mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
      "The model and loaded state dict do not match exactly\n",
      "\n",
      "unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std\n",
      "\n",
      "Model loaded.\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the config file name says 'amp' while the checkpoint name\n",
    "# says 'fp16' - presumably the same recipe under a renamed config; confirm.\n",
    "config_path = (\n",
    "    '../configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py')\n",
    "ckpt_path = 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth'\n",
    "\n",
    "# Build the model from the config and load the checkpoint on the CPU.\n",
    "model = init_model(config_path, ckpt_path, device='cpu')\n",
    "print('Model loaded.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load an image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fb2ccfbac90>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed PyTorch's global random number generator so that the random mask is\n",
    "# reproducible (comment this out to get a different mask on each run).\n",
    "torch.manual_seed(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2022-11-08 11:21:14--  https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
      "正在解析主机 download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
      "正在连接 download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... 已连接。\n",
      "已发出 HTTP 请求，正在等待回应... 200 OK\n",
      "长度： 60133 (59K) [image/jpeg]\n",
      "正在保存至: “fox.jpg”\n",
      "\n",
      "fox.jpg             100%[===================>]  58.72K  --.-KB/s    用时 0.05s   \n",
      "\n",
      "2022-11-08 11:21:15 (1.08 MB/s) - 已保存 “fox.jpg” [60133/60133])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Download the demo image unless the file already exists (`-nc`: no-clobber).\n",
    "!wget -nc 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the image to reconstruct; point this at another file to try\n",
    "# a different image.\n",
    "img_path = 'fox.jpg'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the test dataloader config with a minimal pipeline: load the image\n",
    "# file, resize it to 224x224 with the pillow backend, and pack the result with\n",
    "# `PackSelfSupInputs`, keeping `img_path` as meta information.\n",
    "# NOTE(review): `inference_model` presumably reads this config as well - confirm.\n",
    "model.cfg.test_dataloader = dict(\n",
    "    dataset=dict(pipeline=[\n",
    "        dict(type='LoadImageFromFile'),\n",
    "        dict(type='Resize', scale=(224, 224), backend='pillow'),\n",
    "        dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
    "    ]))\n",
    "\n",
    "# Build a callable pipeline from the config that was just stored on the model.\n",
    "vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pipeline on the image path, then collate the single sample\n",
    "# into a batch.\n",
    "data = dict(img_path=img_path)\n",
    "data = vis_pipeline(data)\n",
    "data = default_collate([data])\n",
    "# The second positional argument is presumably the `training` flag (TODO\n",
    "# confirm). Only the first returned item is used; the second is discarded.\n",
    "img, _ = model.data_preprocessor(data, False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reconstruction pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the preprocessed image into patches with the head's `patchify`\n",
    "# (needed for MAE reconstruction).\n",
    "img_embedding = model.head.patchify(img[0])\n",
    "# Mean and std over the last dimension of the patchified image. Nothing is\n",
    "# normalized here; the statistics are passed to `model.reconstruct` later.\n",
    "# The 1e-6 term keeps the square root away from zero for constant patches.\n",
    "mean = img_embedding.mean(dim=-1, keepdim=True)\n",
    "std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the reconstructed image: run the model on the image file, then pass\n",
    "# its output to `model.reconstruct` together with the `mean` and `std`\n",
    "# computed from the patchified input.\n",
    "features = inference_model(model, img_path)\n",
    "results = model.reconstruct(features, mean=mean, std=std)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `img[0]` is the preprocessed input batch, used here as the reference image.\n",
    "original_target = img[0]\n",
    "# `post_process` de-normalizes with the default statistics of `recover_norm`\n",
    "# and does not use its `mean`/`std` arguments. The `mean`/`std` computed\n",
    "# earlier are statistics of the patchified image, not per-channel values, and\n",
    "# are only meant for `model.reconstruct`, so they are deliberately not passed.\n",
    "original_img, img_masked, pred_img, img_paste = post_process(\n",
    "    original_target, results.pred.value, results.mask.value)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show the image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No `out_file` is given, so the four panels are displayed with `plt.show()`\n",
    "# instead of being written to disk.\n",
    "save_images(original_img, img_masked, pred_img, img_paste)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.0 ('openmmlab')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0 (default, Oct  9 2018, 10:31:47) \n[GCC 7.3.0]"
  },
  "vscode": {
   "interpreter": {
    "hash": "5909b3386efe3692f76356628babf720cfd47771f5d858315790cc041eb41361"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
